{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d8d7c2a-0bdb-4505-bd8d-4486e4e72c8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "##0) Install + Login + Load LLaMA-3.1-8B-Instruct\n",
    "\n",
    "!pip -q install -U transformers accelerate bitsandbytes sentencepiece huggingface_hub openpyxl tqdm\n",
    "python\n",
    "Copy code\n",
    "\n",
    "from huggingface_hub import login\n",
    "login()\n",
    "\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig\n",
    "\n",
    "MODEL_ID = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_ID,\n",
    "    device_map=\"auto\",\n",
    "    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n",
    "    load_in_4bit=True,\n",
    ")\n",
    "\n",
    "# Conservative generation settings for classification stability\n",
    "GEN_CFG = GenerationConfig(\n",
    "    max_new_tokens=220,\n",
    "    temperature=0.2,\n",
    "    top_p=0.9,\n",
    "    do_sample=True,\n",
    ")\n",
    "\n",
    "## 1) Prompt Template + JSON Extraction\n",
    "\n",
    "import json\n",
    "import re\n",
    "\n",
    "PROMPT_TEMPLATE = \"\"\"\n",
    "You are a social science researcher coding a Korean news article for a study on nuclear risk interpretation.\n",
    "\n",
    "Important: This task is NOT about whether the article is favorable or critical toward nuclear energy.\n",
    "Instead, classify the dominant interpretive logic through which nuclear risk is narrated.\n",
    "\n",
    "Definitions:\n",
    "\n",
    "Crisis-based framing (code = 0):\n",
    "- Nuclear risk is presented as an episodic event or policy crisis.\n",
    "- Emphasis on accidents, safety lapses, regulatory failures, or government response.\n",
    "- Focus on elite responsibility, institutional performance, and technical or policy remedies.\n",
    "- Temporal orientation is short-term and reactive.\n",
    "\n",
    "Grievance-based framing (code = 1):\n",
    "- Nuclear risk is presented as an ongoing condition affecting local communities.\n",
    "- Emphasis on cumulative exposure, unequal burden-sharing, and lived experience.\n",
    "- Focus on local residents, everyday vulnerability, and long-term consequences.\n",
    "- Temporal orientation is continuous and cumulative.\n",
    "\n",
    "Task:\n",
    "1) Decide which framing is more dominant in the article: crisis-based (0) or grievance-based (1).\n",
    "2) Provide up to two short excerpts (maximum 25 words each) that justify your decision.\n",
    "3) Provide a confidence score from 0 to 100.\n",
    "\n",
    "Return output in JSON only (no other text):\n",
    "{{\n",
    "  \"Y\": 0 or 1,\n",
    "  \"frame_label\": \"crisis\" or \"grievance\",\n",
    "  \"evidence\": [\"...\", \"...\"],\n",
    "  \"confidence\": 0-100\n",
    "}}\n",
    "\n",
    "Article text:\n",
    "<<<\n",
    "{ARTICLE_TEXT}\n",
    ">>>\n",
    "\"\"\".strip()\n",
    "\n",
    "\n",
    "def extract_json_object(text: str):\n",
    "    \"\"\"\n",
    "    Robustly extract the first JSON object found in a text block.\n",
    "    \"\"\"\n",
    "    m = re.search(r\"\\{.*\\}\", text, flags=re.S)\n",
    "    if not m:\n",
    "        return None\n",
    "    try:\n",
    "        return json.loads(m.group(0))\n",
    "    except json.JSONDecodeError:\n",
    "        return None\n",
    "\n",
    "## 2) Token-Based Segmentation\n",
    "\n",
    "def chunk_text_by_tokens(text: str, max_input_tokens: int = 900, overlap_tokens: int = 80):\n",
    "    \"\"\"\n",
    "    Token-based chunking for long news articles.\n",
    "    \"\"\"\n",
    "    if not isinstance(text, str) or not text.strip():\n",
    "        return []\n",
    "\n",
    "    token_ids = tokenizer.encode(text, add_special_tokens=False)\n",
    "    chunks = []\n",
    "\n",
    "    start = 0\n",
    "    n = len(token_ids)\n",
    "    while start < n:\n",
    "        end = min(start + max_input_tokens, n)\n",
    "        chunk_ids = token_ids[start:end]\n",
    "        chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True)\n",
    "        chunks.append(chunk_text)\n",
    "\n",
    "        if end == n:\n",
    "            break\n",
    "        start = max(0, end - overlap_tokens)\n",
    "\n",
    "    return chunks\n",
    "\n",
    "## 3) Segment-Level Classification Function\n",
    "\n",
    "def classify_segment(seg_text: str):\n",
    "    \"\"\"\n",
    "    Run one segment through the model and return:\n",
    "    - raw completion text\n",
    "    - parsed JSON dict (or None)\n",
    "    \"\"\"\n",
    "    prompt = PROMPT_TEMPLATE.format(ARTICLE_TEXT=seg_text)\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        out_ids = model.generate(**inputs, generation_config=GEN_CFG)\n",
    "\n",
    "    decoded = tokenizer.decode(out_ids[0], skip_special_tokens=True)\n",
    "\n",
    "    # If the model echoed the prompt, remove it\n",
    "    completion = decoded[len(prompt):].strip() if decoded.startswith(prompt) else decoded.strip()\n",
    "\n",
    "    parsed = extract_json_object(completion)\n",
    "    return completion, parsed\n",
    "\n",
    "## 4) Modal Aggregation + Confidence Tie-Break (Article-Level)\n",
    "\n",
    "import numpy as np\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "def _safe_conf(x):\n",
    "    try:\n",
    "        return float(x)\n",
    "    except Exception:\n",
    "        return 0.0\n",
    "\n",
    "def modal_with_confidence_tiebreak(values, confidences):\n",
    "    counts = Counter(values)\n",
    "    top_count = max(counts.values())\n",
    "    candidates = [k for k, v in counts.items() if v == top_count]\n",
    "    if len(candidates) == 1:\n",
    "        return candidates[0]\n",
    "\n",
    "    conf_sum = defaultdict(float)\n",
    "    for v, c in zip(values, confidences):\n",
    "        conf_sum[v] += c\n",
    "    candidates = sorted(candidates, key=lambda x: conf_sum[x], reverse=True)\n",
    "    return candidates[0]\n",
    "\n",
    "def aggregate_article(segment_outputs):\n",
    "    \"\"\"\n",
    "    Aggregate segment-level outputs to an article-level label.\n",
    "\n",
    "    segment_outputs: list of dicts with keys:\n",
    "      - \"parsed\": dict or None\n",
    "      - \"raw\": raw output string\n",
    "    \"\"\"\n",
    "    valid = [s for s in segment_outputs if s.get(\"parsed\") is not None]\n",
    "    if not valid:\n",
    "        return {\n",
    "            \"Y\": None,\n",
    "            \"frame_label\": None,\n",
    "            \"confidence\": None,\n",
    "            \"evidence\": None,\n",
    "            \"n_segments\": len(segment_outputs),\n",
    "            \"n_valid_segments\": 0\n",
    "        }\n",
    "\n",
    "    Ys = []\n",
    "    confs = []\n",
    "    for v in valid:\n",
    "        y = v[\"parsed\"].get(\"Y\")\n",
    "        # Accept Y only if 0/1\n",
    "        if y not in [0, 1]:\n",
    "            continue\n",
    "        Ys.append(y)\n",
    "        confs.append(_safe_conf(v[\"parsed\"].get(\"confidence\")))\n",
    "\n",
    "    if not Ys:\n",
    "        return {\n",
    "            \"Y\": None,\n",
    "            \"frame_label\": None,\n",
    "            \"confidence\": None,\n",
    "            \"evidence\": None,\n",
    "            \"n_segments\": len(segment_outputs),\n",
    "            \"n_valid_segments\": 0\n",
    "        }\n",
    "\n",
    "    Y_hat = modal_with_confidence_tiebreak(Ys, confs)\n",
    "\n",
    "    # Article-level confidence: mean of confidences among segments that match Y_hat\n",
    "    matched_confs = [c for (y, c) in zip(Ys, confs) if y == Y_hat]\n",
    "    article_conf = float(np.mean(matched_confs)) if matched_confs else float(np.mean(confs))\n",
    "\n",
    "    # Evidence: take from highest-confidence segments that match Y_hat\n",
    "    evidence_pool = []\n",
    "    for v in valid:\n",
    "        y = v[\"parsed\"].get(\"Y\")\n",
    "        c = _safe_conf(v[\"parsed\"].get(\"confidence\"))\n",
    "        ev = v[\"parsed\"].get(\"evidence\")\n",
    "        if y == Y_hat and isinstance(ev, list) and len(ev) > 0:\n",
    "            evidence_pool.append((c, ev))\n",
    "\n",
    "    evidence_pool.sort(key=lambda x: x[0], reverse=True)\n",
    "\n",
    "    chosen = []\n",
    "    for _, ev_list in evidence_pool[:3]:\n",
    "        for e in ev_list:\n",
    "            if isinstance(e, str) and e.strip():\n",
    "                chosen.append(e.strip())\n",
    "        if len(chosen) >= 2:\n",
    "            break\n",
    "\n",
    "    chosen = chosen[:2] if chosen else None\n",
    "\n",
    "    frame_label = \"grievance\" if Y_hat == 1 else \"crisis\"\n",
    "\n",
    "    return {\n",
    "        \"Y\": Y_hat,\n",
    "        \"frame_label\": frame_label,\n",
    "        \"confidence\": round(article_conf, 2),\n",
    "        \"evidence\": chosen,\n",
    "        \"n_segments\": len(segment_outputs),\n",
    "        \"n_valid_segments\": len(valid)\n",
    "    }\n",
    "\n",
    "## 5) Excel Upload → Batch Inference → Excel Save (Colab)\n",
    "\n",
    "from tqdm import tqdm\n",
    "from google.colab import files\n",
    "\n",
    "# --- Upload Excel ---\n",
    "uploaded = files.upload()\n",
    "in_file = list(uploaded.keys())[0]\n",
    "\n",
    "df = pd.read_excel(in_file)\n",
    "\n",
    "# --- Set your text column name here ---\n",
    "TEXT_COL = \"content\"  # change to \"contents\" if your file uses that name\n",
    "df[TEXT_COL] = df[TEXT_COL].astype(str)\n",
    "\n",
    "# Chunking parameters\n",
    "MAX_INPUT_TOKENS = 900\n",
    "OVERLAP_TOKENS = 80\n",
    "\n",
    "results = []\n",
    "\n",
    "for i in tqdm(range(len(df))):\n",
    "    text = df.loc[i, TEXT_COL]\n",
    "\n",
    "    chunks = chunk_text_by_tokens(\n",
    "        text,\n",
    "        max_input_tokens=MAX_INPUT_TOKENS,\n",
    "        overlap_tokens=OVERLAP_TOKENS\n",
    "    )\n",
    "\n",
    "    seg_outputs = []\n",
    "    for seg in chunks:\n",
    "        raw, parsed = classify_segment(seg)\n",
    "        seg_outputs.append({\"raw\": raw, \"parsed\": parsed})\n",
    "\n",
    "    agg = aggregate_article(seg_outputs)\n",
    "\n",
    "    results.append({\n",
    "        \"Y\": agg[\"Y\"],  # 1=grievance, 0=crisis\n",
    "        \"frame_label\": agg[\"frame_label\"],\n",
    "        \"confidence\": agg[\"confidence\"],\n",
    "        \"evidence_1\": (agg[\"evidence\"][0] if agg[\"evidence\"] else None),\n",
    "        \"evidence_2\": (agg[\"evidence\"][1] if agg[\"evidence\"] else None),\n",
    "        \"n_segments\": agg[\"n_segments\"],\n",
    "        \"n_valid_segments\": agg[\"n_valid_segments\"]\n",
    "    })\n",
    "\n",
    "df_out = pd.concat([df.reset_index(drop=True), pd.DataFrame(results)], axis=1)\n",
    "\n",
    "# --- Save to Excel ---\n",
    "out_file = \"classified_framing_output.xlsx\"\n",
    "df_out.to_excel(out_file, index=False)\n",
    "\n",
    "files.download(out_file)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
